import os, cv2, gdal
import numpy as np
from pylab import *

from public import *


class FullBandsImg():
    """Container holding one image array per spectral band, addressed 1-based.

    Every slot starts out as the empty-string placeholder until ``set_bands``
    stores a real array in it.
    """

    def __init__(self, band_num=8):
        # Strings are immutable, so repeating the same placeholder is safe.
        self.imgs_bands = [""] * band_num

    def get_bands(self, idx):
        """Return the array stored for band ``idx`` (1-based).

        Note: ``idx=0`` maps to position -1, i.e. the last band.
        """
        pos = idx - 1
        return self.imgs_bands[pos]

    def set_bands(self, idx, img):
        """Store ``img`` as band ``idx`` (1-based)."""
        pos = idx - 1
        self.imgs_bands[pos] = img


class Record():
    """One labelled sample: a circle of ``radius`` pixels centred at (x, y).

    ``name`` is the crop label; its numeric ``class_id`` is looked up in
    ``CLASS_DICT`` and falls back to 0 when the name is not found.
    """

    def __init__(self, name, x, y, radius):
        self.name = name
        self.class_id = CLASS_DICT.get(name)
        if self.class_id is None:
            # Unknown crop names fall back to class 0.
            self.class_id = 0
        # Negative inputs are stored as their magnitude.
        self.x = abs(x)
        self.y = abs(y)
        self.radius = abs(radius)

    def is_in_area(self, xmin, ymin, xmax, ymax):
        """Return True if this record's circle may overlap the box [xmin, xmax] x [ymin, ymax].

        The check compares the circle's bounding square with the box, so a
        circle lying just outside a box corner can still be reported as
        overlapping. Comparisons are strict: a circle exactly tangent to a box
        edge is reported as outside (original TODO: the tangent case still
        needs handling).
        """
        r = self.radius
        return (xmin - r < self.x < xmax + r) and (ymin - r < self.y < ymax + r)

    def crop_save_np(self, out_dir='label_records', full_im=None):
        """Cut the (2*radius+1)-pixel square around this record from every band
        and save it as ``<class_id>-<x>-<y>-<name>-<radius>.npy`` in ``out_dir``.

        :param out_dir: output directory, created if missing.
        :param full_im: object whose ``imgs_bands`` list holds one 2-D array per
            band (e.g. a FullBandsImg); defaults to the module-level
            ``global_full_im``.
        :return: path of the written ``.npy`` file; the saved array has the
            bands stacked along axis 2.
        """
        if full_im is None:
            full_im = global_full_im
        # Clamp the start of the window to 0: a negative slice start would wrap
        # around to the opposite image edge and yield an empty/wrong patch.
        # Slicing past the far edge is already truncated by numpy.
        y0 = max(self.y - self.radius, 0)
        x0 = max(self.x - self.radius, 0)
        y1 = self.y + self.radius + 1
        x1 = self.x + self.radius + 1
        patch = np.stack([band[y0:y1, x0:x1] for band in full_im.imgs_bands], axis=2)
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(
            out_dir, '{}-{}-{}-{}-{}.npy'.format(self.class_id, self.x, self.y, self.name, self.radius))
        np.save(out_path, patch)
        return out_path

    def crop_img(self, img_np, img_width, img_height, crop_size=512):
        """Cut a ``crop_size`` x ``crop_size`` window that contains this record.

        The window is centred on the record and shifted inwards when it would
        cross an image border, so it always lies fully inside the image.

        :param img_np: image indexed as ``img_np[row, col]``.
        :param img_width: image width in pixels (should match ``img_np``).
        :param img_height: image height in pixels (should match ``img_np``).
        :param crop_size: side length of the window.
        :return: ``(cropped, xmin, ymin, xmax, ymax)``; bounds are half-open.
        :raises ValueError: if the image is smaller than ``crop_size`` in
            either dimension.
        """
        xmin, xmax = self._crop_span(self.x, img_width, crop_size)
        ymin, ymax = self._crop_span(self.y, img_height, crop_size)
        return img_np[ymin:ymax, xmin:xmax], xmin, ymin, xmax, ymax

    @staticmethod
    def _crop_span(center, limit, crop_size):
        """Return [lo, hi) of length crop_size centred on center, clamped to [0, limit]."""
        if crop_size > limit:
            raise ValueError(
                'crop_size {} exceeds image dimension {}'.format(crop_size, limit))
        lo = min(max(center - crop_size // 2, 0), limit - crop_size)
        return lo, lo + crop_size

    def fill_in_this_record_with_color(self, cropped_img, crop_xmin, crop_ymin):
        """Draw this record as a filled circle onto a crop of the full image.

        The centre is shifted by (crop_xmin, crop_ymin) into the crop's frame,
        and the colour is ``DICT2COLOR[class_id]``; thickness -1 means filled.
        OpenCV draws into ``cropped_img`` and the image is returned.
        NOTE(review): cv2.circle expects integer coordinates and a valid colour;
        an unmapped class_id yields a None colour — confirm DICT2COLOR covers
        every class.
        """
        cropped_img = cv2.circle(cropped_img, (self.x - crop_xmin, self.y - crop_ymin), self.radius,
                                 DICT2COLOR.get(self.class_id), -1)
        return cropped_img

    def img_color2classkey(self, color_img):
        """Unfinished: intended to convert a colour label image to class ids.

        The method allocates an empty ``(height, width, 1)`` uint8 map and
        currently returns None. NOTE(review): ``case2`` uses the ``image2label``
        helper from ``public`` for a similar conversion — presumably prefer
        that; TODO confirm and either finish or remove this method.
        """
        height, width = color_img.shape[:2]
        class_map = np.zeros((height, width, 1), dtype=np.uint8)
        # TODO: fill class_map from color_img and return it.
        return None


def case1():
    """Manual check of Record.is_in_area for a circle at (5, 5) with radius 2."""
    sample = Record("玉米", 5, 5, 2)
    for area in ((1, 1, 3, 3), (1, 1, 2, 2)):
        print(sample.is_in_area(*area))


def case2():
    """Manual visual check: paint three labelled circles, then show the label map."""
    canvas = np.zeros((512, 512, 3), dtype=np.uint8)

    def show(image):
        # Display one frame, then clear the axes for the next one.
        plt.imshow(image)
        plt.show()
        plt.cla()

    samples = (
        ("玉米", 50, 50, 50),
        ("大豆", 100, 50, 50),
        ("水稻", 100, 100, 50),
    )
    for crop_name, cx, cy, rad in samples:
        canvas = Record(crop_name, cx, cy, rad).fill_in_this_record_with_color(canvas, 0, 0)
        show(canvas)

    canvas = image2label(canvas)
    show(canvas)


def read_all_records(file):
    """Parse the training-sample file into a list of Record objects.

    The first line is a header and is skipped. Every other non-blank line is
    comma-separated: field 2 is the crop name, field 3 the radius (integer),
    and fields 5 and 6 the x / y pixel coordinates (parsed as float, then
    truncated to int).

    :param file: path of the UTF-8 sample file.
    :return: list of Record, in file order.
    """
    records = []
    with open(file, 'r', encoding='utf-8') as fp:
        next(fp, None)  # skip the header line
        for line in fp:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) would otherwise
                # raise IndexError on the field lookups below.
                continue
            fields = line.split(',')
            records.append(Record(
                name=fields[2],
                x=int(float(fields[5])),
                y=int(float(fields[6])),
                radius=int(fields[3]),
            ))
    return records


def read_tif(file_name, full_im=None):
    """Open a multi-band raster with GDAL and load its bands into ``full_im``.

    :param file_name: path of the raster file.
    :param full_im: band container with an ``imgs_bands`` list and 1-based
        ``set_bands`` (e.g. FullBandsImg); defaults to the module-level
        ``global_full_im``. One raster band is read per slot.
    :return: ``(width, height, band_count)`` of the raster, or None when GDAL
        cannot open the file (a message is printed in that case).
    :raises ValueError: if the raster has fewer bands than ``full_im`` has slots.
    """
    dataset = gdal.Open(file_name)
    if dataset is None:
        print(file_name + "文件无法打开")
        return
    if full_im is None:
        full_im = global_full_im
    im_width = dataset.RasterXSize  # number of columns in the raster
    im_height = dataset.RasterYSize  # number of rows in the raster
    im_bands = dataset.RasterCount  # number of bands
    print("{} {} {}".format(im_width, im_height, im_bands))
    n_slots = len(full_im.imgs_bands)
    if im_bands < n_slots:
        # GetRasterBand would return None for a missing band, which previously
        # surfaced as an opaque AttributeError.
        raise ValueError('{}: expected at least {} bands, found {}'.format(file_name, n_slots, im_bands))
    # GDAL band indices are 1-based, matching set_bands.
    for idx in range(1, n_slots + 1):
        full_im.set_bands(idx, dataset.GetRasterBand(idx).ReadAsArray())
    return im_width, im_height, im_bands


if __name__ == '__main__':
    # `sys` is otherwise only reachable through the star imports above.
    import sys

    # Manual test cases:
    # case1()
    # case2()

    # Build the shared band container that read_tif() fills and
    # Record.crop_save_np() reads, then load the labelled samples.
    global_full_im = FullBandsImg()
    crops_records = read_all_records('./train_sample.txt')
    tif_info = read_tif(
        '/home/leo/Downloads/data_set/remote_sense_im/GF6_WFV_E127.9_N46.8_20180823_L1A1119838015.tif')
    if tif_info is None:
        # read_tif() has already printed which file could not be opened.
        sys.exit(1)
    global_remote_sense_image_width, global_remote_sense_image_height, global_remote_sense_image_bands = tif_info

    total = len(crops_records)
    for i, record in enumerate(crops_records, start=1):
        sys.stdout.write('\r>> processing label %d/%d ' % (i, total))
        sys.stdout.flush()
        record.crop_save_np()
    # Terminate the carriage-return progress line.
    sys.stdout.write('\n')
